import os
import sys
import time
import pickle
import codecs

from scipy import sparse

from gensim.corpora.dictionary import Dictionary
from gensim.models.coherencemodel import CoherenceModel
from TU import compute_TU
from KG_parameter import KG
from LNTM_metrics import perplexity, PrecisionAtR

import numpy as np

from collections import Counter

from cllm_parameter_npmi import CLLM
from cllm_parameter_npmi import readDict, findMax

import Parameter

# Print floats in fixed-point notation with 20 decimal places instead of
# scientific notation.
np.set_printoptions(suppress=True,formatter={'float_kind':'{:.20f}'.format})
import sys
# Never abbreviate printed arrays with "...": the summarisation threshold is
# raised to the largest possible element count.
np.set_printoptions(threshold=sys.maxsize)


def training(option, Dt_total, Xt_total, St_total):
    """Train CLLM on domains 0-6 and write the metrics of every domain.

    ``CLLM`` is driven as a generator: each item it yields describes one
    trained domain.  For that domain the vocabulary and the corpus are read
    from ``option.datadir``, the topic word ids are mapped back to words and
    three metrics are computed: NPMI coherence (gensim), TU and perplexity.
    They are written, one value per line and in that order, to
    ``<option.savedir>/<hyper-parameters>_<domain>.txt``.

    Args:
        option: run options.  Read here: ``datadir``, ``savedir``,
            ``lambda_u1``, ``lambda_c``, ``lambda_b1``, ``lambda_b2`` and
            ``eta``.  The object is also handed to ``CLLM`` unchanged.
        Dt_total: per-domain word co-occurrence matrices (passed to ``CLLM``).
        Xt_total: per-domain SPPMI matrices (passed to ``CLLM``).
        St_total: per-domain subword matrices (passed to ``CLLM``).

    Returns:
        None.  The results are only written to disk.
    """
    domain = [0,1,2,3,4,5,6]
    # load total vocab on all of the training datasets and the testing datasets
    dict_path_total = os.path.join(option.datadir, "Dict", "Id2word_total.txt")
    total_id2word, total_word2id = readDict(dict_path_total)

    # KG: starts as an empty word-by-word matrix over the total vocabulary
    total_word_num = len(total_id2word)
    KGraph = sparse.lil_matrix((total_word_num, total_word_num), dtype='float')

    # train each domain and yield the result
    # (accuracy_svm and proportion are yielded by CLLM but not used here)
    for t, topics_id, doc_term, topic_word, doc_topic, accuracy_svm, proportion in CLLM(domain, len(domain), KGraph, total_id2word, total_word2id, option, Dt_total, Xt_total, St_total):
        domain_name = str(domain[t])

        # load vocab
        dict_path_t = os.path.join(option.datadir, "Dict", "Id2word_" + domain_name + ".txt")
        id2word, word2id = readDict(dict_path_t)

        # load texts
        texts = []
        counter = Counter()
        with open(os.path.join(option.datadir, "Cor", "cor_" + domain_name + ".txt"),'r',encoding='utf-8') as f:
            for line in f:
                text = line.strip().split()
                # Skip blank lines.  The raw line still carries its newline,
                # so the token list is what has to be tested.
                if text:
                    texts.append(text)
                    counter.update(text)

        # build corpus
        # NOTE(review): membership is tested against the total vocabulary
        # while the ids come from the domain vocabulary; this assumes every
        # corpus word known to the total dict is also in the domain dict.
        # The frequency stored is the corpus-wide count of the word.
        corpus = [
            [(word2id[word], counter[word]) for word in text if word in total_word2id]
            for text in texts
        ]

        # build dictionary
        dictionary = Dictionary.from_corpus(corpus, id2word=id2word)

        # Label files are not read: no metric computed below uses labels.
        print("build:ok")

        print('start training ...')

        print('t=',t)
        topics = [[id2word[tid] for tid in tids] for tids in topics_id]

        # calculate NPMI using the implementation of gensim
        c_npmi = CoherenceModel(topics=topics, texts=texts, dictionary=dictionary,coherence='c_npmi', topn=20)
        c_npmi_value = c_npmi.get_coherence()
        # calculate TU
        TU_value = compute_TU(topics)
        # calculate perplexity
        perplexity1 = perplexity(doc_term, topic_word, doc_topic)

        # write results; the context manager guarantees the file is flushed
        # and closed before the next domain is processed
        parameterNow = [option.lambda_u1, option.lambda_c, option.lambda_b1, option.lambda_b2, option.eta]
        result_path = os.path.join(option.savedir, str(parameterNow) + "_" + domain_name + ".txt")
        with open(result_path,'w',encoding='utf-8') as out:
            for value in (c_npmi_value, TU_value, perplexity1):
                out.write(str(value))
                out.write("\n")


def testing(option, domain, Dt_total, Xt_total, St_total):
    """Run CLLM on the given test domain(s) and write their metrics.

    The knowledge graph is loaded from
    ``<option.savedir>/KG_<hyper-parameters>_6.npz`` and handed to ``CLLM``
    together with ``domain``.  For every item ``CLLM`` yields, the domain
    vocabulary and corpus are read from ``option.datadir`` and four values
    are written, one per line, to
    ``<option.savedir>/<hyper-parameters>_<domain>.txt``: NPMI coherence
    (gensim), TU, perplexity and the ``accuracy_svm`` value yielded by
    ``CLLM``.

    Args:
        option: run options.  Read here: ``datadir``, ``savedir``,
            ``lambda_u1``, ``lambda_c``, ``lambda_b1``, ``lambda_b2`` and
            ``eta``.  The object is also handed to ``CLLM`` unchanged.
        domain: list of domain ids; indexed with the ``t`` yielded by
            ``CLLM`` to build the file names.
        Dt_total: per-domain word co-occurrence matrices (passed to ``CLLM``).
        Xt_total: per-domain SPPMI matrices (passed to ``CLLM``).
        St_total: per-domain subword matrices (passed to ``CLLM``).

    Returns:
        None.  The results are only written to disk.
    """
    parameterNow = [option.lambda_u1, option.lambda_c, option.lambda_b1, option.lambda_b2, option.eta]

    dict_path_total = os.path.join(option.datadir, "Dict", "Id2word_total.txt")
    total_id2word, total_word2id = readDict(dict_path_total)

    # KG
    # NOTE(review): the "_6" suffix presumably refers to the last training
    # domain; confirm against the code that saves this file.
    KGraph = sparse.load_npz(os.path.join(option.savedir, "KG_" + str(parameterNow) + "_6" + ".npz")).tolil()

    # proportion is yielded by CLLM but not used here
    for t, topics_id, doc_term, topic_word, doc_topic, accuracy_svm, proportion in CLLM(domain, 1, KGraph, total_id2word, total_word2id, option, Dt_total, Xt_total, St_total):
        domain_name = str(domain[t])

        # load vocab
        dict_path_t = os.path.join(option.datadir, "Dict", "Id2word_" + domain_name + ".txt")
        id2word, word2id = readDict(dict_path_t)

        # load texts
        texts = []
        counter = Counter()
        with open(os.path.join(option.datadir, "Cor", "cor_" + domain_name + ".txt"),'r',encoding='utf-8') as f:
            for line in f:
                text = line.strip().split()
                # Skip blank lines.  The raw line still carries its newline,
                # so the token list is what has to be tested.
                if text:
                    texts.append(text)
                    counter.update(text)

        # build corpus
        # NOTE(review): membership is tested against the total vocabulary
        # while the ids come from the domain vocabulary; this assumes every
        # corpus word known to the total dict is also in the domain dict.
        # The frequency stored is the corpus-wide count of the word.
        corpus = [
            [(word2id[word], counter[word]) for word in text if word in total_word2id]
            for text in texts
        ]

        # build dictionary
        dictionary = Dictionary.from_corpus(corpus, id2word=id2word)

        print("build:ok")

        print('start training ...')

        print('t=',t)
        topics = [[id2word[tid] for tid in tids] for tids in topics_id]

        print(topics)

        # calculate NPMI using the implementation of gensim
        c_npmi = CoherenceModel(topics=topics, texts=texts, dictionary=dictionary,coherence='c_npmi', topn=20)
        c_npmi_value = c_npmi.get_coherence()
        # calculate TU
        TU_value = compute_TU(topics)
        # calculate perplexity
        perplexity1 = perplexity(doc_term, topic_word, doc_topic)

        # write results; the context manager guarantees the file is flushed
        # and closed before the next domain is processed
        result_path = os.path.join(option.savedir, str(parameterNow) + "_" + domain_name + ".txt")
        with open(result_path,'w',encoding='utf-8') as out:
            for value in (c_npmi_value, TU_value, perplexity1, accuracy_svm):
                out.write(str(value))
                out.write("\n")


def perform_grid_search(option, Dt_total, Xt_total, St_total):
    """Train and test CLLM once per hyper-parameter combination.

    Every combination of the candidate values listed below is written into
    ``option`` (``lambda_u1``, ``lambda_c``, ``lambda_b1``, ``lambda_b2``,
    ``eta``); then ``training`` is run, followed by ``testing`` on domain 7
    and on domain 8.

    Args:
        option: run options.  It is modified in place: after the call it
            holds the values of the last combination tried.
        Dt_total: per-domain word co-occurrence matrices.
        Xt_total: per-domain SPPMI matrices.
        St_total: per-domain subword matrices.

    Returns:
        None.
    """
    from itertools import product

    # grid search hyper-parameters setting
    lambda_u1_list = [10]
    lambda_c_list = [10]
    lambda_b1_list = [0.1]
    lambda_b2_list = [0.001]
    eta_list = [1]

    # product() varies the right-most list fastest, i.e. the same order as
    # nested loops with lambda_u1 outermost and eta innermost.
    parameterGroup = [
        list(combination)
        for combination in product(
            lambda_u1_list, lambda_c_list, lambda_b1_list, lambda_b2_list, eta_list
        )
    ]

    for i, parameterNow in enumerate(parameterGroup):
        print(i)
        print(parameterNow)

        # optionNow is the same object as option, not a copy
        optionNow = option
        optionNow.lambda_u1 = parameterNow[0]
        # training() and testing() read ``lambda_c``
        optionNow.lambda_c = parameterNow[1]
        # The value used to be stored under the misspelled name ``lamda_c``;
        # keep that attribute in sync in case other modules read it.
        optionNow.lamda_c = parameterNow[1]
        optionNow.lambda_b1 = parameterNow[2]
        optionNow.lambda_b2 = parameterNow[3]
        optionNow.eta = parameterNow[4]

        # train CLM with parameterGroup[i]
        train_start_time = time.time()
        training(optionNow, Dt_total, Xt_total, St_total)
        print('training cost = %d' % (time.time()-train_start_time))
        # test CLM
        for j in range(2):
            domain = [7 + j]
            testing(optionNow, domain, Dt_total, Xt_total, St_total)



if __name__=='__main__':
    # Script entry point: parse the options, load the matrices of all nine
    # domains, then either grid-search or run one train/test cycle.
    option = Parameter.parse_argv(sys.argv[1:])

    domain_num = 9
    domain_total = list(range(domain_num))  # [0, 1, ..., 8]
    Dt_total = [[] for _ in range(domain_num)] # word co-occurrence matrices of domains
    Xt_total = [[] for _ in range(domain_num)] # sppmi matrices of domains
    St_total = [[] for _ in range(domain_num)] # subword matrices of domains

    for index, domain_id in enumerate(domain_total):
        file_suffix = str(domain_id) + ".npy"

        # load Dt & Xt & St matrices
        Dt_total[index] = np.load(os.path.join(option.datadir, "Dt", "X_" + file_suffix))
        Xt_total[index] = np.load(os.path.join(option.datadir, "sppmi", "sppmi_" + file_suffix))
        St_raw = np.load(os.path.join(option.datadir, "subword", "X_Comp_relation_" + file_suffix))
        St_origin = sparse.lil_matrix(St_raw)

        # process St by applying a lower bound to the weights: only entries
        # above option.R_s times the value returned by findMax are kept
        maxSt = findMax(St_origin)
        print(maxSt)
        threshold_5 = maxSt * option.R_s

        # start from the identity so that the diagonal is always 1
        num_word = len(Xt_total[index])
        St_final = sparse.eye(num_word, num_word, dtype='float', format='lil')

        kept = 0
        for row, col in zip(*St_origin.nonzero()):
            # off-diagonal entries above the threshold are copied, divided
            # by maxSt; everything else keeps its identity value
            if row != col and St_origin[row, col] > threshold_5:
                kept += 1
                St_final[row, col] = St_origin[row, col] / maxSt

        print('count of nnz elements in St =', kept)

        St_total[index] = St_final

    if not option.grid_search:
        # train CLM with the hyper-parameters given in option
        train_start_time = time.time()
        training(option, Dt_total, Xt_total, St_total)
        print('training cost = %d' % (time.time()-train_start_time))
        # test CLM on the two held-out domains
        for test_domain in (7, 8):
            testing(option, [test_domain], Dt_total, Xt_total, St_total)
    else:
        perform_grid_search(option, Dt_total, Xt_total, St_total)
